diff --git a/common/chat.cpp b/common/chat.cpp index 8587140e1ff0a..63583fb22489d 100644 --- a/common/chat.cpp +++ b/common/chat.cpp @@ -9,8 +9,11 @@ #include #include +#include #include +#include #include +#include #include #include #include @@ -640,6 +643,7 @@ const char * common_chat_format_name(common_chat_format format) { case COMMON_CHAT_FORMAT_SEED_OSS: return "Seed-OSS"; case COMMON_CHAT_FORMAT_NEMOTRON_V2: return "Nemotron V2"; case COMMON_CHAT_FORMAT_APERTUS: return "Apertus"; + case COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS: return "LFM2 with JSON tools"; default: throw std::runtime_error("Unknown chat format"); } @@ -986,6 +990,126 @@ static common_chat_params common_chat_params_init_mistral_nemo(const common_chat return data; } + +// Case-insensitive find +static size_t ifind_string(const std::string & haystack, const std::string & needle, size_t pos = 0) { + auto it = std::search( + haystack.begin() + pos, haystack.end(), + needle.begin(), needle.end(), + [](char a, char b) { return std::tolower(a) == std::tolower(b); } + ); + return (it == haystack.end()) ? std::string::npos : std::distance(haystack.begin(), it); +} + +static common_chat_params common_chat_params_init_lfm2(const common_chat_template & tmpl, const struct templates_params & inputs) { + common_chat_params data; + const auto is_json_schema_provided = !inputs.json_schema.is_null(); + const auto is_grammar_provided = !inputs.grammar.empty(); + const auto are_tools_provided = inputs.tools.is_array() && !inputs.tools.empty(); + + // the logic requires potentially modifying the messages + auto tweaked_messages = inputs.messages; + + auto replace_json_schema_marker = [](json & messages) -> bool { + static std::string marker1 = "force json schema.\n"; + static std::string marker2 = "force json schema."; + + if (messages.empty() || messages.at(0).at("role") != "system") { + return false; + } + + std::string content = messages.at(0).at("content"); + + for (const auto & marker : {marker1, marker2}) { + const auto pos = ifind_string(content, marker); + if (pos != std::string::npos) { + content.replace(pos, marker.length(), ""); + // inject modified content back into the messages + messages.at(0).at("content") = content; + return true; + } + } + + return false; + }; + + // Lfm2 model does not natively work with json, but can generally understand the tools structure + // + // Example of the pytorch dialog structure: + // <|startoftext|><|im_start|>system + // List of tools: <|tool_list_start|>[{"name": "get_candidate_status", "description": "Retrieves the current status of a candidate in the recruitment process", "parameters": {"type": "object", "properties": {"candidate_id": {"type": "string", "description": "Unique identifier for the candidate"}}, "required": ["candidate_id"]}}]<|tool_list_end|><|im_end|> + // <|im_start|>user + // What is the current status of candidate ID 12345?<|im_end|> + // <|im_start|>assistant + // <|tool_call_start|>[get_candidate_status(candidate_id="12345")]<|tool_call_end|>Checking the current status of candidate ID 12345.<|im_end|> + // <|im_start|>tool + // <|tool_response_start|>{"candidate_id": "12345", "status": "Interview Scheduled", "position": "Clinical Research Associate", "date": "2023-11-20"}<|tool_response_end|><|im_end|> + // <|im_start|>assistant + // The candidate with ID 12345 is currently in the "Interview Scheduled" stage for the position of Clinical Research Associate, with an interview date set for 2023-11-20.<|im_end|> + // + // For the llama server compatibility with json 
tools semantic, + // the client can add "Follow json schema." line into the system message prompt to force the json output. + // + if (are_tools_provided && (is_json_schema_provided || is_grammar_provided)) { + // server/utils.hpp prohibits that branch for the custom grammar anyways + throw std::runtime_error("Tools call must not use \"json_schema\" or \"grammar\", use non-tool invocation if you want to use custom grammar"); + } else if (are_tools_provided && replace_json_schema_marker(tweaked_messages)) { + LOG_INF("%s: Using tools to build a grammar\n", __func__); + + data.grammar = build_grammar([&](const common_grammar_builder & builder) { + auto schemas = json::array(); + foreach_function(inputs.tools, [&](const json & tool) { + const auto & function = tool.at("function"); + schemas.push_back({ + {"type", "object"}, + {"properties", { + {"name", { + {"type", "string"}, + {"const", function.at("name")}, + }}, + {"arguments", function.at("parameters")}, + }}, + {"required", json::array({"name", "arguments", "id"})}, + }); + }); + auto schema = json { + {"type", "array"}, + {"items", schemas.size() == 1 ? schemas[0] : json {{"anyOf", schemas}}}, + {"minItems", 1}, + }; + if (!inputs.parallel_tool_calls) { + schema["maxItems"] = 1; + } + + builder.add_rule("root", "\"<|tool_call_start|>\"" + builder.add_schema("tool_calls", schema) + "\"<|tool_call_end|>\""); + }); + // model has no concept of tool selection mode choice, + // if the system prompt rendered correctly it will produce a tool call + // the grammar goes inside the tool call body + data.grammar_lazy = true; + data.grammar_triggers = {{COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL, "\\s*<\\|tool_call_start\\|>\\s*\\["}}; + data.preserved_tokens = {"<|tool_call_start|>", "<|tool_call_end|>"}; + data.format = COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS; + } else if (are_tools_provided && (!is_json_schema_provided && !is_grammar_provided)) { + LOG_INF("%s: Using tools without json schema or grammar\n", __func__); + // output those tokens + data.preserved_tokens = {"<|tool_call_start|>", "<|tool_call_end|>"}; + } else if (is_json_schema_provided) { + LOG_INF("%s: Using provided json schema to build a grammar\n", __func__); + data.grammar = json_schema_to_grammar(inputs.json_schema); + } else if (is_grammar_provided) { + LOG_INF("%s: Using provided grammar\n", __func__); + data.grammar = inputs.grammar; + } else { + LOG_INF("%s: Using content relying on the template\n", __func__); + } + + data.prompt = apply(tmpl, inputs, /* messages_override= */ tweaked_messages); + LOG_DBG("%s: Prompt: %s\n", __func__, data.prompt.c_str()); + + return data; +} + static common_chat_params common_chat_params_init_magistral(const common_chat_template & tmpl, const struct templates_params & inputs) { common_chat_params data; data.prompt = apply(tmpl, inputs); @@ -2499,6 +2623,71 @@ static void common_chat_parse_apertus(common_chat_msg_parser & builder) { builder.add_content(builder.consume_rest()); } + +static void common_chat_parse_lfm2(common_chat_msg_parser & builder) { + if (!builder.syntax().parse_tool_calls) { + builder.add_content(builder.consume_rest()); + return; + } + + // LFM2 format: <|tool_call_start|>[{"name": "get_current_time", "arguments": {"location": "Paris"}}]<|tool_call_end|> + static const common_regex tool_call_start_regex(regex_escape("<|tool_call_start|>")); + static const common_regex tool_call_end_regex(regex_escape("<|tool_call_end|>")); + + // Loop through all tool calls + while (auto res = 
builder.try_find_regex(tool_call_start_regex, std::string::npos, /* add_prelude_to_content= */ true)) { + builder.move_to(res->groups[0].end); + + // Parse JSON array format: [{"name": "...", "arguments": {...}}] + auto tool_calls_data = builder.consume_json(); + + // Consume end marker + builder.consume_spaces(); + if (!builder.try_consume_regex(tool_call_end_regex)) { + throw common_chat_msg_partial_exception("Expected <|tool_call_end|>"); + } + + // Process each tool call in the array + if (tool_calls_data.json.is_array()) { + for (const auto & tool_call : tool_calls_data.json) { + if (!tool_call.is_object()) { + throw common_chat_msg_partial_exception("Tool call must be an object"); + } + + if (!tool_call.contains("name")) { + throw common_chat_msg_partial_exception("Tool call missing 'name' field"); + } + + std::string function_name = tool_call.at("name"); + std::string arguments = "{}"; + + if (tool_call.contains("arguments")) { + if (tool_call.at("arguments").is_object()) { + arguments = tool_call.at("arguments").dump(); + } else if (tool_call.at("arguments").is_string()) { + arguments = tool_call.at("arguments"); + } + } + + if (!builder.add_tool_call(function_name, "", arguments)) { + throw common_chat_msg_partial_exception("Incomplete tool call"); + } + } + } else { + throw common_chat_msg_partial_exception("Expected JSON array for tool calls"); + } + + // Consume any trailing whitespace after this tool call + builder.consume_spaces(); + } + + // Consume any remaining content after all tool calls + auto remaining = builder.consume_rest(); + if (!string_strip(remaining).empty()) { + builder.add_content(remaining); + } +} + static void common_chat_parse_seed_oss(common_chat_msg_parser & builder) { // Parse thinking tags first - this handles the main reasoning content builder.try_parse_reasoning("", ""); @@ -2748,6 +2937,12 @@ static common_chat_params common_chat_templates_apply_jinja( return common_chat_params_init_apertus(tmpl, params); } + // LFM2 (w/ tools) + if (src.find("List of tools: <|tool_list_start|>[") != std::string::npos && + src.find("]<|tool_list_end|>") != std::string::npos) { + return common_chat_params_init_lfm2(tmpl, params); + } + // Use generic handler when mixing tools + JSON schema. // TODO: support that mix in handlers below. 
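    // Note: the LFM2 handler above is selected before this fallback and rejects the
    // tools + json_schema combination itself (common_chat_params_init_lfm2 throws),
    // so only the remaining templates reach the generic handler with that mix.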
if ((params.tools.is_array() && params.json_schema.is_object())) { @@ -2926,6 +3121,9 @@ static void common_chat_parse(common_chat_msg_parser & builder) { case COMMON_CHAT_FORMAT_APERTUS: common_chat_parse_apertus(builder); break; + case COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS: + common_chat_parse_lfm2(builder); + break; default: throw std::runtime_error(std::string("Unsupported format: ") + common_chat_format_name(builder.syntax().format)); } diff --git a/common/chat.h b/common/chat.h index f7b36ec711df4..50efb0d4e516f 100644 --- a/common/chat.h +++ b/common/chat.h @@ -116,6 +116,7 @@ enum common_chat_format { COMMON_CHAT_FORMAT_SEED_OSS, COMMON_CHAT_FORMAT_NEMOTRON_V2, COMMON_CHAT_FORMAT_APERTUS, + COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS, COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats }; diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 093f2ab467f4d..b759366684396 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -2460,18 +2460,21 @@ def set_gguf_parameters(self): ) class LlavaVisionModel(MmprojModel): img_break_tok_id = -1 + use_break_tok = True def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) if self.hparams.get("model_type") == "pixtral": # layer_norm_eps is not in config.json, it is hard-coded in modeling_pixtral.py self.hparams["layer_norm_eps"] = self.hparams.get("layer_norm_eps", 1e-5) - self.img_break_tok_id = self.get_token_id("[IMG_BREAK]") + if self.use_break_tok: + self.img_break_tok_id = self.get_token_id("[IMG_BREAK]") elif self.is_mistral_format: # hparams is already vision config here so norm_eps is only defined in global_config. self.hparams["norm_eps"] = self.global_config.get("norm_eps", None) assert self.hparams["norm_eps"] is not None, "norm_eps not found in params.json" - self.img_break_tok_id = self.find_vparam(["image_break_token_id"]) + if self.use_break_tok: + self.img_break_tok_id = self.find_vparam(["image_break_token_id"]) else: raise ValueError(f"Unsupported model type: {self.hparams['model_type']}") logger.info(f"Image break token id: {self.img_break_tok_id}") @@ -3962,6 +3965,10 @@ def _get_cls_out_tensor(self, data_torch: Tensor) -> Tensor: return torch.stack([true_row, false_row], dim=0) def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + if "model.vision_" in name: + # skip multimodal tensors + return [] + if self.is_rerank: is_tied_head = self.is_tied_embeddings and "embed_tokens" in name is_real_head = not self.is_tied_embeddings and "lm_head" in name @@ -9435,6 +9442,21 @@ def map_tensor_name(self, name: str, try_suffixes: Sequence[str] = (".weight", " return super().map_tensor_name(name, try_suffixes) +@ModelBase.register("LightOnOCRForConditionalGeneration") +class LightOnOCRVisionModel(LlavaVisionModel): + is_mistral_format = False + use_break_tok = False + + def set_gguf_parameters(self): + super().set_gguf_parameters() + self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.LIGHTONOCR) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None): + name = name.replace("model.vision_encoder.", "vision_tower.") + name = name.replace("model.vision_projection.", "multi_modal_projector.") + return super().modify_tensors(data_torch, name, bid) + + @ModelBase.register("KimiVLForConditionalGeneration") class KimiVLModel(MmprojModel): def __init__(self, *args, **kwargs): diff --git a/docs/build.md b/docs/build.md index dcbcce7549ad2..b410c710e30d3 100644 --- a/docs/build.md +++ b/docs/build.md @@ -261,10 
+261,12 @@ You can download it from your Linux distro's package manager or from here: [ROCm - Using `CMake` for Linux (assuming a gfx1030-compatible AMD GPU): ```bash HIPCXX="$(hipconfig -l)/clang" HIP_PATH="$(hipconfig -R)" \ - cmake -S . -B build -DGGML_HIP=ON -DAMDGPU_TARGETS=gfx1030 -DCMAKE_BUILD_TYPE=Release \ + cmake -S . -B build -DGGML_HIP=ON -DGPU_TARGETS=gfx1030 -DCMAKE_BUILD_TYPE=Release \ && cmake --build build --config Release -- -j 16 ``` + Note: `GPU_TARGETS` is optional, omitting it will build the code for all GPUs in the current system. + To enhance flash attention performance on RDNA3+ or CDNA architectures, you can utilize the rocWMMA library by enabling the `-DGGML_HIP_ROCWMMA_FATTN=ON` option. This requires rocWMMA headers to be installed on the build system. The rocWMMA library is included by default when installing the ROCm SDK using the `rocm` meta package provided by AMD. Alternatively, if you are not using the meta package, you can install the library using the `rocwmma-dev` or `rocwmma-devel` package, depending on your system's package manager. @@ -282,17 +284,17 @@ You can download it from your Linux distro's package manager or from here: [ROCm ```bash HIPCXX="$(hipconfig -l)/clang" HIP_PATH="$(hipconfig -p)" \ HIP_DEVICE_LIB_PATH= \ - cmake -S . -B build -DGGML_HIP=ON -DAMDGPU_TARGETS=gfx1030 -DCMAKE_BUILD_TYPE=Release \ + cmake -S . -B build -DGGML_HIP=ON -DGPU_TARGETS=gfx1030 -DCMAKE_BUILD_TYPE=Release \ && cmake --build build -- -j 16 ``` - Using `CMake` for Windows (using x64 Native Tools Command Prompt for VS, and assuming a gfx1100-compatible AMD GPU): ```bash set PATH=%HIP_PATH%\bin;%PATH% - cmake -S . -B build -G Ninja -DAMDGPU_TARGETS=gfx1100 -DGGML_HIP=ON -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_BUILD_TYPE=Release + cmake -S . -B build -G Ninja -DGPU_TARGETS=gfx1100 -DGGML_HIP=ON -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_BUILD_TYPE=Release cmake --build build ``` - Make sure that `AMDGPU_TARGETS` is set to the GPU arch you want to compile for. The above example uses `gfx1100` that corresponds to Radeon RX 7900XTX/XT/GRE. You can find a list of targets [here](https://llvm.org/docs/AMDGPUUsage.html#processors) + If necessary, adapt `GPU_TARGETS` to the GPU arch you want to compile for. The above example uses `gfx1100` that corresponds to Radeon RX 7900XTX/XT/GRE. You can find a list of targets [here](https://llvm.org/docs/AMDGPUUsage.html#processors) Find your gpu version string by matching the most significant version information from `rocminfo | grep gfx | head -1 | awk '{print $2}'` with the list of processors, e.g. `gfx1035` maps to `gfx1030`. diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index b52f0f8472cfe..3156bd60101d7 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -7519,8 +7519,8 @@ static void ggml_compute_forward_upscale_f32( float pixel_offset = 0.5f; if (mode_flags & GGML_SCALE_FLAG_ALIGN_CORNERS) { pixel_offset = 0.0f; - sf0 = (float)(ne0 - 1) / (src0->ne[0] - 1); - sf1 = (float)(ne1 - 1) / (src0->ne[1] - 1); + sf0 = ne0 > 1 && ne00 > 1 ? (float)(ne0 - 1) / (ne00 - 1) : sf0; + sf1 = ne1 > 1 && ne01 > 1 ? 
(float)(ne1 - 1) / (ne01 - 1) : sf1; } for (int64_t i3 = 0; i3 < ne3; i3++) { diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 6b688bfecdedd..94ab1ec0f5a90 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -2976,7 +2976,7 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, if (ops.size() == topk_moe_ops_with_norm.size() && ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 3, node_idx + 8 })) { ggml_tensor * softmax = cgraph->nodes[node_idx]; - ggml_tensor * weights = cgraph->nodes[node_idx+8]; + ggml_tensor * weights = cgraph->nodes[node_idx + 9]; if (ggml_cuda_should_use_topk_moe(softmax, weights)) { return true; @@ -2986,7 +2986,7 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, if (ops.size() == topk_moe_ops.size() && ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 3, node_idx + 4 })) { ggml_tensor * softmax = cgraph->nodes[node_idx]; - ggml_tensor * weights = cgraph->nodes[node_idx+4]; + ggml_tensor * weights = cgraph->nodes[node_idx + 4]; if (ggml_cuda_should_use_topk_moe(softmax, weights)) { return true; } @@ -3125,17 +3125,18 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx if (!disable_fusion) { if (ggml_cuda_can_fuse(cgraph, i, ggml_cuda_topk_moe_ops(/*with norm*/ true), {})) { - ggml_tensor * weights = cgraph->nodes[i+8]; - ggml_tensor * selected_experts = cgraph->nodes[i+3]; + ggml_tensor * weights = cgraph->nodes[i + 9]; + ggml_tensor * selected_experts = cgraph->nodes[i + 3]; + ggml_tensor * clamp = cgraph->nodes[i + 7]; ggml_cuda_op_topk_moe(*cuda_ctx, node->src[0], weights, selected_experts, /*with norm*/ true, - /*delayed softmax*/ false); - i += 8; + /*delayed softmax*/ false, clamp); + i += 9; continue; } if (ggml_cuda_can_fuse(cgraph, i, ggml_cuda_topk_moe_ops(/*with norm*/ false), {})) { - ggml_tensor * weights = cgraph->nodes[i+4]; - ggml_tensor * selected_experts = cgraph->nodes[i+3]; + ggml_tensor * weights = cgraph->nodes[i + 4]; + ggml_tensor * selected_experts = cgraph->nodes[i + 3]; ggml_cuda_op_topk_moe(*cuda_ctx, node->src[0], weights, selected_experts, /*with norm*/ false, /*delayed softmax*/ false); i += 4; diff --git a/ggml/src/ggml-cuda/topk-moe.cu b/ggml/src/ggml-cuda/topk-moe.cu index e28c810ac5df7..572379fcbf0e8 100644 --- a/ggml/src/ggml-cuda/topk-moe.cu +++ b/ggml/src/ggml-cuda/topk-moe.cu @@ -2,6 +2,7 @@ #include "ggml.h" #include "topk-moe.cuh" +#include #include // Warp-local softmax used for both the pre-top-k logits and the post-top-k delayed path. 
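For context on the node-offset changes in ggml-cuda.cu above: with weight normalization, the fused top-k MoE subgraph now carries a GGML_OP_CLAMP between GGML_OP_SUM_ROWS and GGML_OP_DIV, which is why the final weights tensor moves from `node_idx + 8` to `node_idx + 9` and the clamp node is picked up at `i + 7`. The sketch below is illustrative only (not part of the patch); it rebuilds just that normalization tail with the same ggml calls the updated `test_topk_moe` uses, with made-up tensor sizes.

```cpp
// Minimal sketch of the with-norm normalization tail that the fused CUDA path
// now matches (GGML_OP_SUM_ROWS -> GGML_OP_CLAMP -> GGML_OP_DIV), following the
// same ggml calls used by test_topk_moe in tests/test-backend-ops.cpp.
// The 6.103515625e-5f lower bound is 2^-14, the smallest normal fp16 value.
#include "ggml.h"

#include <math.h>   // INFINITY

int main(void) {
    struct ggml_init_params ip = {
        /*.mem_size   =*/ 16u*1024u*1024u,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(ip);

    const int64_t n_expert_used = 4, n_tokens = 8;
    // stand-in for the gathered per-token expert weights [n_expert_used, n_tokens]
    struct ggml_tensor * selected = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_expert_used, n_tokens);

    struct ggml_tensor * weights_sum = ggml_sum_rows(ctx, selected);            // [1, n_tokens]
    weights_sum = ggml_clamp(ctx, weights_sum, 6.103515625e-5f, INFINITY);      // guard against a ~0 sum
    struct ggml_tensor * weights = ggml_div(ctx, selected, weights_sum);        // normalized weights

    (void) weights;
    ggml_free(ctx);
    return 0;
}
```

ggml_cuda_should_use_topk_moe() only accepts the fusion when the clamp's upper bound is +INFINITY, i.e. when the clamp acts purely as a lower bound on the row sum.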
@@ -63,7 +64,8 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float * float * weights, int32_t * ids, const int n_rows, - const int n_expert_used) { + const int n_expert_used, + const float clamp_val) { const int row = blockIdx.x * blockDim.y + threadIdx.y; if (row >= n_rows) { return; @@ -139,6 +141,7 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float * if constexpr (with_norm) { wt_sum = warp_reduce_sum(wt_sum); + wt_sum = max(wt_sum, clamp_val); const float inv_sum = 1.0f / wt_sum; for (int i = 0; i < experts_per_thread; i++) { @@ -157,6 +160,10 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float * weights[idx] = output_weights[i]; } } + + if (!with_norm) { + GGML_UNUSED(clamp_val); + } } template @@ -166,9 +173,9 @@ static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx, int32_t * ids, const int n_rows, const int n_expert, - const int n_expert_used) { + const int n_expert_used, + const float clamp_val) { static_assert(!(with_norm && delayed_softmax), "delayed softmax is not supported with weight normalization"); - const int rows_per_block = 4; dim3 grid_dims((n_rows + rows_per_block - 1) / rows_per_block, 1, 1); dim3 block_dims(WARP_SIZE, rows_per_block, 1); @@ -177,43 +184,43 @@ static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx, switch (n_expert) { case 1: topk_moe_cuda<1, with_norm, delayed_softmax> - <<>>(logits, weights, ids, n_rows, n_expert_used); + <<>>(logits, weights, ids, n_rows, n_expert_used, clamp_val); break; case 2: topk_moe_cuda<2, with_norm, delayed_softmax> - <<>>(logits, weights, ids, n_rows, n_expert_used); + <<>>(logits, weights, ids, n_rows, n_expert_used, clamp_val); break; case 4: topk_moe_cuda<4, with_norm, delayed_softmax> - <<>>(logits, weights, ids, n_rows, n_expert_used); + <<>>(logits, weights, ids, n_rows, n_expert_used, clamp_val); break; case 8: topk_moe_cuda<8, with_norm, delayed_softmax> - <<>>(logits, weights, ids, n_rows, n_expert_used); + <<>>(logits, weights, ids, n_rows, n_expert_used, clamp_val); break; case 16: topk_moe_cuda<16, with_norm, delayed_softmax> - <<>>(logits, weights, ids, n_rows, n_expert_used); + <<>>(logits, weights, ids, n_rows, n_expert_used, clamp_val); break; case 32: topk_moe_cuda<32, with_norm, delayed_softmax> - <<>>(logits, weights, ids, n_rows, n_expert_used); + <<>>(logits, weights, ids, n_rows, n_expert_used, clamp_val); break; case 64: topk_moe_cuda<64, with_norm, delayed_softmax> - <<>>(logits, weights, ids, n_rows, n_expert_used); + <<>>(logits, weights, ids, n_rows, n_expert_used, clamp_val); break; case 128: topk_moe_cuda<128, with_norm, delayed_softmax> - <<>>(logits, weights, ids, n_rows, n_expert_used); + <<>>(logits, weights, ids, n_rows, n_expert_used, clamp_val); break; case 256: topk_moe_cuda<256, with_norm, delayed_softmax> - <<>>(logits, weights, ids, n_rows, n_expert_used); + <<>>(logits, weights, ids, n_rows, n_expert_used, clamp_val); break; case 512: topk_moe_cuda<512, with_norm, delayed_softmax> - <<>>(logits, weights, ids, n_rows, n_expert_used); + <<>>(logits, weights, ids, n_rows, n_expert_used, clamp_val); break; default: GGML_ASSERT(false && "fatal error"); @@ -226,7 +233,8 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx, ggml_tensor * weights, ggml_tensor * ids, const bool with_norm, - const bool delayed_softmax) { + const bool delayed_softmax, + ggml_tensor * clamp) { GGML_ASSERT(logits->type == GGML_TYPE_F32); GGML_ASSERT(weights->type == GGML_TYPE_F32); 
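    // The optional `clamp` argument is the fused GGML_OP_CLAMP node applied to the
    // row sum before the division; when it is absent, clamp_val below stays at
    // -INFINITY and the kernel's max(wt_sum, clamp_val) is a no-op.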
GGML_ASSERT(ids->type == GGML_TYPE_I32); @@ -242,18 +250,25 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx, const int n_expert_used = weights->ne[1]; + float clamp_val = -INFINITY; if (with_norm) { - launch_topk_moe_cuda(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used); + if (clamp) { + clamp_val = ggml_get_op_params_f32(clamp, 0); + } + launch_topk_moe_cuda(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used, clamp_val); } else { + GGML_ASSERT(clamp == nullptr); if (delayed_softmax) { - launch_topk_moe_cuda(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used); + launch_topk_moe_cuda(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used, + clamp_val); } else { - launch_topk_moe_cuda(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used); + launch_topk_moe_cuda(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used, + clamp_val); } } } -bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tensor * weights) { +bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tensor * weights, const ggml_tensor * clamp) { float scale = 1.0f; float max_bias = 0.0f; @@ -279,13 +294,26 @@ bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tenso return false; } + if (clamp) { + if (clamp->op != GGML_OP_CLAMP) { + return false; + } + float max_val = ggml_get_op_params_f32(clamp, 1); + + if (max_val != INFINITY) { + return false; + } + } + + return true; } std::initializer_list ggml_cuda_topk_moe_ops(bool norm, bool delayed_softmax) { static std::initializer_list norm_ops = { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT, GGML_OP_VIEW, GGML_OP_GET_ROWS, GGML_OP_RESHAPE, - GGML_OP_SUM_ROWS, GGML_OP_DIV, GGML_OP_RESHAPE }; + GGML_OP_SUM_ROWS, GGML_OP_CLAMP, GGML_OP_DIV, + GGML_OP_RESHAPE }; static std::initializer_list no_norm_ops = { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT, GGML_OP_VIEW, GGML_OP_GET_ROWS }; diff --git a/ggml/src/ggml-cuda/topk-moe.cuh b/ggml/src/ggml-cuda/topk-moe.cuh index cc2fbfe9e6649..2eff408b03058 100644 --- a/ggml/src/ggml-cuda/topk-moe.cuh +++ b/ggml/src/ggml-cuda/topk-moe.cuh @@ -8,8 +8,9 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx, ggml_tensor * weights, ggml_tensor * ids, const bool with_norm, - const bool delayed_softmax = false); + const bool delayed_softmax = false, + ggml_tensor * weight_clamp = nullptr); -bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tensor * weights); +bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tensor * weights, const ggml_tensor * clamp = nullptr); std::initializer_list ggml_cuda_topk_moe_ops(bool with_norm, bool delayed_softmax = false); diff --git a/ggml/src/ggml-cuda/upscale.cu b/ggml/src/ggml-cuda/upscale.cu index ef48aa5f97bcd..35b7e61d80ac9 100644 --- a/ggml/src/ggml-cuda/upscale.cu +++ b/ggml/src/ggml-cuda/upscale.cu @@ -126,8 +126,8 @@ void ggml_cuda_op_upscale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { } else if (mode == GGML_SCALE_MODE_BILINEAR) { float pixel_offset = 0.5f; if (mode_flags & GGML_SCALE_FLAG_ALIGN_CORNERS) { - sf0 = (float)(dst->ne[0] - 1) / (src0->ne[0] - 1); - sf1 = (float)(dst->ne[1] - 1) / (src0->ne[1] - 1); + sf0 = dst->ne[0] > 1 && src0->ne[0] > 1 ? (float)(dst->ne[0] - 1) / (src0->ne[0] - 1) : sf0; + sf1 = dst->ne[1] > 1 && src0->ne[1] > 1 ? 
(float)(dst->ne[1] - 1) / (src0->ne[1] - 1) : sf1; pixel_offset = 0.0f; } upscale_f32_bilinear_cuda(src0_d, dst_d, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], diff --git a/ggml/src/ggml-hip/CMakeLists.txt b/ggml/src/ggml-hip/CMakeLists.txt index 6b499320e7b12..23b6889919f20 100644 --- a/ggml/src/ggml-hip/CMakeLists.txt +++ b/ggml/src/ggml-hip/CMakeLists.txt @@ -29,10 +29,11 @@ if (CXX_IS_HIPCC) endif() else() # Forward (AMD)GPU_TARGETS to CMAKE_HIP_ARCHITECTURES. + if(AMDGPU_TARGETS AND NOT GPU_TARGETS) + set(GPU_TARGETS ${AMDGPU_TARGETS}) + endif() if(GPU_TARGETS AND NOT CMAKE_HIP_ARCHITECTURES) set(CMAKE_HIP_ARCHITECTURES ${GPU_TARGETS}) - elseif(AMDGPU_TARGETS AND NOT CMAKE_HIP_ARCHITECTURES) - set(CMAKE_HIP_ARCHITECTURES ${AMDGPU_TARGETS}) endif() cmake_minimum_required(VERSION 3.21) enable_language(HIP) diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index db33a4ab6c2e3..93a3600b63f07 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -6156,8 +6156,8 @@ static void ggml_cl_upscale(ggml_backend_t backend, const ggml_tensor * src0, gg CL_CHECK(clSetKernelArg(kernel, 15, sizeof(float), &sf3)); } else if (mode == GGML_SCALE_MODE_BILINEAR) { if (mode_flags & GGML_SCALE_FLAG_ALIGN_CORNERS) { - sf0 = (float)(ne0 - 1) / (ne00 - 1); - sf1 = (float)(ne1 - 1) / (ne01 - 1); + sf0 = ne0 > 1 && ne00 > 1 ? (float)(ne0 - 1) / (ne00 - 1) : sf0; + sf1 = ne1 > 1 && ne01 > 1 ? (float)(ne1 - 1) / (ne01 - 1) : sf1; pixel_offset = 0.0f; } diff --git a/ggml/src/ggml-sycl/backend.hpp b/ggml/src/ggml-sycl/backend.hpp index b1575b8145138..ca53f3e90068c 100644 --- a/ggml/src/ggml-sycl/backend.hpp +++ b/ggml/src/ggml-sycl/backend.hpp @@ -32,6 +32,7 @@ #include "pad.hpp" #include "quantize.hpp" #include "quants.hpp" +#include "roll.hpp" #include "rope.hpp" #include "set_rows.hpp" #include "softmax.hpp" diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index b695ba051b025..62d0ecd94ee0a 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -48,6 +48,7 @@ #include "ggml-sycl/set.hpp" #include "ggml-sycl/sycl_hw.hpp" #include "ggml-sycl/getrows.hpp" +#include "ggml-sycl/repeat_back.hpp" #include "ggml-sycl/quantize.hpp" #include "ggml.h" @@ -2615,6 +2616,10 @@ catch (sycl::exception const &exc) { std::exit(1); } +static void ggml_sycl_repeat_back(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); + ggml_sycl_op_repeat_back(ctx, dst); +} static void ggml_sycl_get_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2); @@ -3679,6 +3684,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg case GGML_OP_REPEAT: ggml_sycl_repeat(ctx, dst); break; + case GGML_OP_REPEAT_BACK: + ggml_sycl_repeat_back(ctx, dst); + break; case GGML_OP_GET_ROWS: ggml_sycl_get_rows(ctx, dst); break; @@ -3913,6 +3921,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg case GGML_OP_GATED_LINEAR_ATTN: ggml_sycl_op_gated_linear_attn(ctx, dst); break; + case GGML_OP_ROLL: + ggml_sycl_roll(ctx, dst); + break; case GGML_OP_ARANGE: ggml_sycl_arange(ctx, dst); break; @@ -4516,6 +4527,11 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g ggml_type src0_type = op->src[0]->type; return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16; } + case 
GGML_OP_REPEAT_BACK: + { + ggml_type src0_type = op->src[0]->type; + return src0_type == GGML_TYPE_F32; + } case GGML_OP_DUP: case GGML_OP_ARGMAX: case GGML_OP_NONE: @@ -4586,6 +4602,8 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_RWKV_WKV7: case GGML_OP_GATED_LINEAR_ATTN: return true; + case GGML_OP_ROLL: + return op->type == GGML_TYPE_F32; case GGML_OP_ARANGE: return op->type == GGML_TYPE_F32; default: diff --git a/ggml/src/ggml-sycl/repeat_back.cpp b/ggml/src/ggml-sycl/repeat_back.cpp new file mode 100644 index 0000000000000..abcd4cee72a48 --- /dev/null +++ b/ggml/src/ggml-sycl/repeat_back.cpp @@ -0,0 +1,56 @@ +#include "repeat_back.hpp" + +#include "common.hpp" + +void ggml_sycl_op_repeat_back(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + + const float * src0_dd = (const float *) dst->src[0]->data; + float * dst_dd = (float *) dst->data; + + const int64_t ne0 = dst->ne[0], ne1 = dst->ne[1], ne2 = dst->ne[2], ne3 = dst->ne[3]; + const int64_t ne00 = dst->src[0]->ne[0], ne01 = dst->src[0]->ne[1], ne02 = dst->src[0]->ne[2], + ne03 = dst->src[0]->ne[3]; + + const int nr0 = (int) (ne00 / ne0); + const int nr1 = (int) (ne01 / ne1); + const int nr2 = (int) (ne02 / ne2); + const int nr3 = (int) (ne03 / ne3); + + const size_t total = ne0 * ne1 * ne2 * ne3; + const int BLOCK_SIZE = 256; + const int num_blocks = (total + BLOCK_SIZE - 1) / BLOCK_SIZE; + + queue_ptr stream = ctx.stream(); + + stream->parallel_for( + sycl::nd_range<1>(sycl::range<1>(num_blocks * BLOCK_SIZE), sycl::range<1>(BLOCK_SIZE)), + [=](sycl::nd_item<1> item_ct1) { + const size_t i = item_ct1.get_global_linear_id(); + if (i >= total) { + return; + } + + const int i0 = i % ne0; + const int i1 = (i / ne0) % ne1; + const int i2 = (i / (ne0 * ne1)) % ne2; + const int i3 = i / (ne0 * ne1 * ne2); + + float acc = 0.0f; + + for (int j3 = 0; j3 < nr3; ++j3) { + for (int j2 = 0; j2 < nr2; ++j2) { + for (int j1 = 0; j1 < nr1; ++j1) { + for (int j0 = 0; j0 < nr0; ++j0) { + acc += src0_dd[(i0 + j0 * ne0) + (i1 + j1 * ne1) * ne00 + (i2 + j2 * ne2) * ne00 * ne01 + + (i3 + j3 * ne3) * ne00 * ne01 * ne02]; + } + } + } + } + + dst_dd[i] = acc; + }); +} diff --git a/ggml/src/ggml-sycl/repeat_back.hpp b/ggml/src/ggml-sycl/repeat_back.hpp new file mode 100644 index 0000000000000..17a87f3e159b3 --- /dev/null +++ b/ggml/src/ggml-sycl/repeat_back.hpp @@ -0,0 +1,8 @@ +#ifndef GGML_SYCL_REPEAT_BACK_HPP +#define GGML_SYCL_REPEAT_BACK_HPP + +#include "common.hpp" + +void ggml_sycl_op_repeat_back(ggml_backend_sycl_context & ctx, ggml_tensor * dst); + +#endif // GGML_SYCL_REPEAT_BACK_HPP diff --git a/ggml/src/ggml-sycl/roll.cpp b/ggml/src/ggml-sycl/roll.cpp new file mode 100644 index 0000000000000..1e05181789c28 --- /dev/null +++ b/ggml/src/ggml-sycl/roll.cpp @@ -0,0 +1,122 @@ +#include "roll.hpp" +#include "common.hpp" + +using namespace sycl; + +static inline int wrap_add(int i, int shift, int n) { + + int s = i + shift; + return (s >= n) ? 
(s - n) : s; +} + +static void kernel_roll_fused_i0_i1( + queue &q, + const float *src_d, + float *dst_d, + int ne0, int ne1, int ne2, int ne3, + int sh0, int sh1, int sh2, int sh3) +{ + if (ne0 == 0 || ne1 == 0 || ne2 == 0 || ne3 == 0) return; + + + const int stride1 = ne0; + const int stride2 = ne0 * ne1; + const int stride3 = ne0 * ne1 * ne2; + + + const int shNe0 = (ne0 - sh0) % ne0; + const int shNe1 = (ne1 - sh1) % ne1; + const int shNe2 = (ne2 - sh2) % ne2; + const int shNe3 = (ne3 - sh3) % ne3; + + + const size_t g0 = (size_t) ne3; + const size_t g1 = (size_t) ne2; + const size_t g2 = (size_t) (ne1 * ne0); + + const range<3> global{ g0, g1, g2 }; + + q.submit([&](handler &h) { + h.parallel_for(global, [=](id<3> idx) { + const int i3 = (int) idx[0]; + const int i2 = (int) idx[1]; + + const int fused = (int) idx[2]; + const int i1 = fused / ne0; + const int i0 = fused - i1 * ne0; // fused % ne0 + + + const int idx_dst = i0 + + i1 * stride1 + + i2 * stride2 + + i3 * stride3; + + + const int s0 = wrap_add(i0, shNe0, ne0); + const int s1 = wrap_add(i1, shNe1, ne1); + const int s2 = wrap_add(i2, shNe2, ne2); + const int s3 = wrap_add(i3, shNe3, ne3); + + const int idx_src = s0 + + s1 * stride1 + + s2 * stride2 + + s3 * stride3; + + dst_d[idx_dst] = src_d[idx_src]; + }); + }); +} + +void ggml_sycl_roll(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { + GGML_ASSERT(dst->type == GGML_TYPE_F32); + + const ggml_tensor *src = dst->src[0]; + GGML_ASSERT(src && src->type == GGML_TYPE_F32); + + const int ne0 = (int) dst->ne[0]; + const int ne1 = (int) dst->ne[1]; + const int ne2 = (int) dst->ne[2]; + const int ne3 = (int) dst->ne[3]; + + const int32_t *params = (const int32_t *) dst->op_params; + int shift0 = params[0]; + int shift1 = params[1]; + int shift2 = params[2]; + int shift3 = params[3]; + + + if ((shift0 | shift1 | shift2 | shift3) == 0) { + const size_t nb = ggml_nbytes(src); + queue *q = ctx.stream(); + SYCL_CHECK(CHECK_TRY_ERROR(q->memcpy(dst->data, src->data, nb))); + return; + } + + auto norm = [](int sh, int n) -> int { + if (n <= 0) return 0; + sh %= n; + if (sh < 0) sh += n; + return sh; + }; + shift0 = norm(shift0, ne0); + shift1 = norm(shift1, ne1); + shift2 = norm(shift2, ne2); + shift3 = norm(shift3, ne3); + + try { + queue *q = ctx.stream(); + + const float *src_d = (const float *) src->data; + float *dst_d = (float *) dst->data; + GGML_ASSERT(src_d && dst_d); + + kernel_roll_fused_i0_i1( + *q, src_d, dst_d, + ne0, ne1, ne2, ne3, + shift0, shift1, shift2, shift3 + ); + } catch (const std::exception &e) { + std::fprintf(stderr, "[SYCL-ROLL] ERROR: %s\n", e.what()); + throw; + } +} diff --git a/ggml/src/ggml-sycl/roll.hpp b/ggml/src/ggml-sycl/roll.hpp new file mode 100644 index 0000000000000..97dc03d64b24d --- /dev/null +++ b/ggml/src/ggml-sycl/roll.hpp @@ -0,0 +1,20 @@ +// +// MIT license +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: MIT +// + +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// + +#ifndef GGML_SYCL_ROLL_HPP +#define GGML_SYCL_ROLL_HPP + +#include "common.hpp" + +void ggml_sycl_roll(ggml_backend_sycl_context & ctx, ggml_tensor *dst); + +#endif // GGML_SYCL_ROLL_HPP diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index b783f7805e924..173677a2637a9 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -523,7 +523,7 @@ struct vk_device_struct { vk_pipeline pipeline_add_id_f32; vk_pipeline pipeline_concat_f32, pipeline_concat_f16, pipeline_concat_i32; - vk_pipeline pipeline_upscale_nearest_f32, pipeline_upscale_bilinear_f32, pipeline_upscale_bilinear_ac_f32; + vk_pipeline pipeline_upscale_nearest_f32, pipeline_upscale_bilinear_f32; vk_pipeline pipeline_scale_f32; vk_pipeline pipeline_sqr_f32; vk_pipeline pipeline_sqrt_f32; @@ -1238,6 +1238,7 @@ struct vk_op_upscale_push_constants { uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03; uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13; float sf0; float sf1; float sf2; float sf3; + float pixel_offset; }; struct vk_op_sum_rows_push_constants @@ -3493,7 +3494,6 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_upscale_nearest_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {GGML_SCALE_MODE_NEAREST}, 1); ggml_vk_create_pipeline(device, device->pipeline_upscale_bilinear_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {GGML_SCALE_MODE_BILINEAR}, 1); - ggml_vk_create_pipeline(device, device->pipeline_upscale_bilinear_ac_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {GGML_SCALE_MODE_BILINEAR | GGML_SCALE_FLAG_ALIGN_CORNERS}, 1); ggml_vk_create_pipeline(device, device->pipeline_scale_f32, "scale_f32", scale_f32_len, scale_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); @@ -7798,14 +7798,14 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return nullptr; case GGML_OP_UPSCALE: if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { - int mode = ggml_get_op_params_i32(dst, 0); + ggml_scale_mode mode = (ggml_scale_mode)(ggml_get_op_params_i32(dst, 0) & 0xFF); switch (mode) { case GGML_SCALE_MODE_NEAREST: return ctx->device->pipeline_upscale_nearest_f32; case GGML_SCALE_MODE_BILINEAR: return ctx->device->pipeline_upscale_bilinear_f32; - case GGML_SCALE_MODE_BILINEAR | GGML_SCALE_FLAG_ALIGN_CORNERS: - return ctx->device->pipeline_upscale_bilinear_ac_f32; + default: + return nullptr; } } return nullptr; @@ -9294,22 +9294,26 @@ static void ggml_vk_upscale(ggml_backend_vk_context * ctx, vk_context& subctx, c const uint32_t src0_type_size = ggml_type_size(src0->type); const uint32_t mode = (uint32_t)ggml_get_op_params_i32(dst, 0); - float sf0 = (float)dst->ne[0] / src0->ne[0]; - float sf1 = (float)dst->ne[1] / src0->ne[1]; - float sf2 = (float)dst->ne[2] / src0->ne[2]; - float sf3 = (float)dst->ne[3] / src0->ne[3]; + GGML_TENSOR_UNARY_OP_LOCALS + + float sf0 = (float)ne0 / ne00; + float sf1 = (float)ne1 / ne01; + float sf2 = (float)ne2 / ne02; + float sf3 = (float)ne3 / ne03; + float pixel_offset = 0.5f; if (mode & GGML_SCALE_FLAG_ALIGN_CORNERS) { - sf0 = (float)(dst->ne[0] - 1) / (src0->ne[0] - 1); - sf1 = (float)(dst->ne[1] - 1) / (src0->ne[1] - 
1); + sf0 = ne0 > 1 && ne00 > 1 ? (float)(ne0 - 1) / (ne00 - 1) : sf0; + sf1 = ne1 > 1 && ne01 > 1 ? (float)(ne1 - 1) / (ne01 - 1) : sf1; + pixel_offset = 0.0f; } ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UPSCALE, { (uint32_t)ggml_nelements(dst), 0, 0, - (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], - (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, - (uint32_t)dst->ne[0], (uint32_t)dst->ne[1], (uint32_t)dst->ne[2],(uint32_t)dst->ne[3], - sf0, sf1, sf2, sf3, + (uint32_t)ne00, (uint32_t)ne01, + (uint32_t)nb00 / src0_type_size, (uint32_t)nb01 / src0_type_size, (uint32_t)nb02 / src0_type_size, (uint32_t)nb03 / src0_type_size, + (uint32_t)ne0, (uint32_t)ne1, (uint32_t)ne2, (uint32_t)ne3, + sf0, sf1, sf2, sf3, pixel_offset }, dryrun); } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp b/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp index 154a2172d83db..8670aad32c380 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp @@ -7,6 +7,7 @@ layout (push_constant) uniform parameter uint nb00; uint nb01; uint nb02; uint nb03; uint ne10; uint ne11; uint ne12; uint ne13; float sf0; float sf1; float sf2; float sf3; + float pixel_offset; } p; #include "types.glsl" @@ -19,7 +20,6 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; // from ggml.h: enum ggml_scale_mode, enum ggml_scale_flag #define NEAREST 0 #define BILINEAR 1 -#define ALIGN_CORNERS (1 << 8) layout (constant_id = 0) const uint scale_mode = 0; @@ -52,7 +52,7 @@ float fetch_bilinear(ivec2 c0, ivec2 c1, vec2 d, uint i12, uint i13) { float interpolate_bilinear(uint i10, uint i11, uint i12, uint i13) { const ivec2 ne0 = ivec2(p.ne00, p.ne01); - const vec2 c = (vec2(i10, i11) + 0.5) / vec2(p.sf0, p.sf1) - 0.5; + const vec2 c = (vec2(i10, i11) + p.pixel_offset) / vec2(p.sf0, p.sf1) - p.pixel_offset; const vec2 c0f = floor(c); const vec2 d = c - c0f; const ivec2 c0 = max(ivec2(c0f), 0); @@ -61,16 +61,6 @@ float interpolate_bilinear(uint i10, uint i11, uint i12, uint i13) { return fetch_bilinear(c0, c1, d, i12, i13); } -float interpolate_bilinear_align_corners(uint i10, uint i11, uint i12, uint i13) { - const vec2 c = vec2(i10, i11) / vec2(p.sf0, p.sf1); - const vec2 c0f = floor(c); - const vec2 d = c - c0f; - const ivec2 c0 = ivec2(c0f); - const ivec2 c1 = c0 + 1; - - return fetch_bilinear(c0, c1, d, i12, i13); -} - void main() { const uint idx = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; @@ -91,9 +81,6 @@ void main() { case BILINEAR: result = interpolate_bilinear(i10, i11, i12, i13); break; - case BILINEAR | ALIGN_CORNERS: - result = interpolate_bilinear_align_corners(i10, i11, i12, i13); - break; } data_d[p.d_offset + idx] = D_TYPE(result); diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 1b71fb3749aaa..94fcfaf69cf09 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -3062,6 +3062,7 @@ class VisionProjectorType: VOXTRAL = "voxtral" LFM2 = "lfm2" KIMIVL = "kimivl" + LIGHTONOCR = "lightonocr" # Items here are (block size, type size) diff --git a/models/templates/llama-cpp-lfm2.jinja b/models/templates/llama-cpp-lfm2.jinja new file mode 100644 index 0000000000000..b7921120bc007 --- /dev/null +++ b/models/templates/llama-cpp-lfm2.jinja @@ -0,0 +1,37 @@ +{{- bos_token -}} +{%- set system_prompt = "" -%} +{%- set ns = 
namespace(system_prompt="") -%} +{%- if messages[0]["role"] == "system" -%} + {%- set ns.system_prompt = messages[0]["content"] -%} + {%- set messages = messages[1:] -%} +{%- endif -%} +{%- if tools -%} + {%- set ns.system_prompt = ns.system_prompt + ("\n" if ns.system_prompt else "") + "List of tools: <|tool_list_start|>[" -%} + {%- for tool in tools -%} + {%- if tool is not string -%} + {%- set tool = tool | tojson -%} + {%- endif -%} + {%- set ns.system_prompt = ns.system_prompt + tool -%} + {%- if not loop.last -%} + {%- set ns.system_prompt = ns.system_prompt + ", " -%} + {%- endif -%} + {%- endfor -%} + {%- set ns.system_prompt = ns.system_prompt + "]<|tool_list_end|>" -%} +{%- endif -%} +{%- if ns.system_prompt -%} + {{- "<|im_start|>system\n" + ns.system_prompt + "<|im_end|>\n" -}} +{%- endif -%} +{%- for message in messages -%} + {{- "<|im_start|>" + message["role"] + "\n" -}} + {%- set content = message["content"] -%} + {%- if content is not string -%} + {%- set content = content | tojson -%} + {%- endif -%} + {%- if message["role"] == "tool" -%} + {%- set content = "<|tool_response_start|>" + content + "<|tool_response_end|>" -%} + {%- endif -%} + {{- content + "<|im_end|>\n" -}} +{%- endfor -%} +{%- if add_generation_prompt -%} + {{- "<|im_start|>assistant\n" -}} +{%- endif -%} diff --git a/src/llama-context.cpp b/src/llama-context.cpp index bd348bcad370a..f6192a36e0ee5 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -268,9 +268,7 @@ llama_context::llama_context( if (pipeline_parallel) { LLAMA_LOG_INFO("%s: pipeline parallelism enabled (n_copies=%d)\n", __func__, ggml_backend_sched_get_n_copies(sched.get())); } - } - if (!hparams.vocab_only) { llama_memory_context_ptr mctx; if (memory) { LLAMA_LOG_DEBUG("%s: reserving full memory module\n", __func__); @@ -343,7 +341,14 @@ llama_context::llama_context( { auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get()); if (!gf) { - throw std::runtime_error("failed to allocate compute pp buffers"); + if (pipeline_parallel) { + LLAMA_LOG_WARN("%s: compute buffer allocation failed, retrying without pipeline parallelism\n", __func__); + sched.reset(ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), max_nodes, false, cparams.op_offload)); + gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get()); + } + if (!gf) { + throw std::runtime_error("failed to allocate compute pp buffers"); + } } n_splits_pp = ggml_backend_sched_get_n_splits(sched.get()); diff --git a/src/llama-model.cpp b/src/llama-model.cpp index b88ff51f5da1d..05e467180089e 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -15,7 +15,6 @@ #include #include -#include #include #include #include @@ -438,7 +437,7 @@ struct llama_model::impl { llama_mlocks mlock_mmaps; // contexts where the model tensors metadata is stored as well ass the corresponding buffers: - std::vector> ctxs_bufs; + std::vector>> ctxs_bufs; buft_list_t cpu_buft_list; std::map gpu_buft_list; @@ -6186,7 +6185,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { bool buffer_from_host_ptr_supported = props.caps.buffer_from_host_ptr; bool is_default_buft = buft == ggml_backend_dev_buffer_type(dev); - ggml_backend_buffer_t buf = nullptr; + std::vector bufs; if (ml.use_mmap && use_mmap_buffer && buffer_from_host_ptr_supported && is_default_buft) { for (uint32_t idx = 0; idx < ml.files.size(); idx++) { // only the mmap region containing the tensors in the model is mapped to the backend buffer @@ -6199,15 +6198,16 @@ bool 
llama_model::load_tensors(llama_model_loader & ml) { continue; } const size_t max_size = ggml_get_max_tensor_size(ctx); - buf = ggml_backend_dev_buffer_from_host_ptr(dev, (char *) addr + first, last - first, max_size); + ggml_backend_buffer_t buf = ggml_backend_dev_buffer_from_host_ptr(dev, (char *) addr + first, last - first, max_size); if (buf == nullptr) { throw std::runtime_error(format("unable to allocate %s buffer", ggml_backend_buft_name(buft))); } + bufs.emplace_back(buf); buf_map.emplace(idx, buf); } } else { - buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft); + ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft); if (buf == nullptr) { throw std::runtime_error(format("unable to allocate %s buffer", ggml_backend_buft_name(buft))); } @@ -6217,11 +6217,12 @@ bool llama_model::load_tensors(llama_model_loader & ml) { mlock_buf->init (ggml_backend_buffer_get_base(buf)); mlock_buf->grow_to(ggml_backend_buffer_get_size(buf)); } + bufs.emplace_back(buf); for (uint32_t idx = 0; idx < ml.files.size(); idx++) { buf_map.emplace(idx, buf); } } - pimpl->ctxs_bufs.emplace_back(std::move(ctx_ptr), buf); + pimpl->ctxs_bufs.emplace_back(std::move(ctx_ptr), std::move(bufs)); for (auto & buf : buf_map) { // indicate that this buffer contains weights @@ -6247,8 +6248,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } // print memory requirements per buffer type - for (auto & [_, buf] : pimpl->ctxs_bufs) { - LLAMA_LOG_INFO("%s: %12s model buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf.get()), ggml_backend_buffer_get_size(buf.get()) / 1024.0 / 1024.0); + for (auto & [_, bufs] : pimpl->ctxs_bufs) { + for (auto & buf: bufs) { + LLAMA_LOG_INFO("%s: %12s model buffer size = %8.2f MiB\n", + __func__, ggml_backend_buffer_name(buf.get()), ggml_backend_buffer_get_size(buf.get()) / 1024.0 / 1024.0); + } } // populate tensors_by_name @@ -6300,8 +6304,10 @@ size_t llama_model::n_devices() const { std::map llama_model::memory_breakdown() const { std::map ret; - for (const auto & [_, buf] : pimpl->ctxs_bufs) { - ret[ggml_backend_buffer_get_type(buf.get())] += ggml_backend_buffer_get_size(buf.get()); + for (const auto & [_, bufs] : pimpl->ctxs_bufs) { + for (const auto & buf : bufs) { + ret[ggml_backend_buffer_get_type(buf.get())] += ggml_backend_buffer_get_size(buf.get()); + } } return ret; } diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 0ad73e944edac..aee1730137900 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -511,7 +511,7 @@ struct test_result { }; // Printer classes for different output formats -enum class test_status_t { NOT_SUPPORTED, OK, FAIL }; +enum class test_status_t { NOT_SUPPORTED, OK, FAIL, SKIPPED }; struct test_operation_info { std::string op_name; @@ -687,6 +687,8 @@ struct printer { virtual void print_backend_status(const backend_status_info & info) { (void) info; } virtual void print_overall_summary(const overall_summary_info & info) { (void) info; } + + virtual void print_failed_tests(const std::vector & failed_tests) { (void) failed_tests; } }; struct console_printer : public printer { @@ -804,6 +806,17 @@ struct console_printer : public printer { } } + void print_failed_tests(const std::vector & failed_tests) override { + if (failed_tests.empty()) { + return; + } + + printf("\nFailing tests:\n"); + for (const auto & test_name : failed_tests) { + printf(" %s\n", test_name.c_str()); + } + } + private: void print_test_console(const test_result & result) { printf(" 
%s(%s): ", result.op_name.c_str(), result.op_params.c_str()); @@ -1056,6 +1069,8 @@ struct test_case { std::vector sentinels; + std::string current_op_name; + void add_sentinel(ggml_context * ctx) { if (mode == MODE_PERF || mode == MODE_GRAD || mode == MODE_SUPPORT) { return; @@ -1127,7 +1142,10 @@ struct test_case { } } - bool eval(ggml_backend_t backend1, ggml_backend_t backend2, const char * op_names_filter, printer * output_printer) { + test_status_t eval(ggml_backend_t backend1, + ggml_backend_t backend2, + const char * op_names_filter, + printer * output_printer) { mode = MODE_TEST; ggml_init_params params = { @@ -1144,11 +1162,12 @@ struct test_case { add_sentinel(ctx); ggml_tensor * out = build_graph(ctx); - std::string current_op_name = op_desc(out); + current_op_name = op_desc(out); + if (!matches_filter(out, op_names_filter)) { //printf(" %s: skipping\n", op_desc(out).c_str()); ggml_free(ctx); - return true; + return test_status_t::SKIPPED; } // check if the backends support the ops @@ -1172,7 +1191,7 @@ struct test_case { } ggml_free(ctx); - return true; + return test_status_t::NOT_SUPPORTED; } // post-graph sentinel @@ -1184,7 +1203,7 @@ struct test_case { if (buf == NULL) { printf("failed to allocate tensors [%s] ", ggml_backend_name(backend1)); ggml_free(ctx); - return false; + return test_status_t::FAIL; } // build graph @@ -1289,7 +1308,7 @@ struct test_case { output_printer->print_test_result(result); } - return test_passed; + return test_passed ? test_status_t::OK : test_status_t::FAIL; } bool eval_perf(ggml_backend_t backend, const char * op_names_filter, printer * output_printer) { @@ -1306,7 +1325,7 @@ struct test_case { GGML_ASSERT(ctx); ggml_tensor * out = build_graph(ctx.get()); - std::string current_op_name = op_desc(out); + current_op_name = op_desc(out); if (!matches_filter(out, op_names_filter)) { //printf(" %s: skipping\n", op_desc(out).c_str()); return true; @@ -1435,8 +1454,9 @@ struct test_case { ggml_context_ptr ctx(ggml_init(params)); // smart ptr GGML_ASSERT(ctx); - ggml_tensor * out = build_graph(ctx.get()); - std::string current_op_name = op_desc(out); + ggml_tensor * out = build_graph(ctx.get()); + current_op_name = op_desc(out); + if (!matches_filter(out, op_names_filter)) { return true; } @@ -4712,6 +4732,7 @@ struct test_topk_moe: public test_case { out = ggml_reshape_2d(ctx, out, n_expert_used, n_tokens); ggml_tensor * weights_sum = ggml_sum_rows(ctx, out); // [1, n_tokens] + weights_sum = ggml_clamp(ctx, weights_sum, 6.103515625e-5, INFINITY); out = ggml_div(ctx, out, weights_sum); // [n_expert_used, n_tokens] out = ggml_reshape_3d(ctx, out, 1, n_expert_used, n_tokens); } @@ -7028,6 +7049,8 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_interpolate(GGML_TYPE_F32, {5, 7, 11, 13}, {2, 5, 7, 11}, mode)); } test_cases.emplace_back(new test_interpolate(GGML_TYPE_F32, {2, 5, 7, 11}, {5, 7, 11, 13}, GGML_SCALE_MODE_BILINEAR | GGML_SCALE_FLAG_ALIGN_CORNERS)); + test_cases.emplace_back(new test_interpolate(GGML_TYPE_F32, {1, 4, 3, 2}, {2, 8, 3, 2}, GGML_SCALE_MODE_BILINEAR | GGML_SCALE_FLAG_ALIGN_CORNERS)); + test_cases.emplace_back(new test_interpolate(GGML_TYPE_F32, {4, 1, 3, 2}, {1, 1, 3, 2}, GGML_SCALE_MODE_BILINEAR | GGML_SCALE_FLAG_ALIGN_CORNERS)); test_cases.emplace_back(new test_sum()); test_cases.emplace_back(new test_sum_rows()); @@ -7359,16 +7382,26 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op } size_t n_ok = 0; + size_t tests_run = 0; + std::vector failed_tests; for (auto 
& test : test_cases) { - if (test->eval(backend, backend_cpu, op_names_filter, output_printer)) { + test_status_t status = test->eval(backend, backend_cpu, op_names_filter, output_printer); + if (status == test_status_t::SKIPPED || status == test_status_t::NOT_SUPPORTED) { + continue; + } + tests_run++; + if (status == test_status_t::OK) { n_ok++; + } else if (status == test_status_t::FAIL) { + failed_tests.push_back(test->current_op_name + "(" + test->vars() + ")"); } } - output_printer->print_summary(test_summary_info(n_ok, test_cases.size(), false)); + output_printer->print_summary(test_summary_info(n_ok, tests_run, false)); + output_printer->print_failed_tests(failed_tests); ggml_backend_free(backend_cpu); - return n_ok == test_cases.size(); + return n_ok == tests_run; } if (mode == MODE_GRAD) { diff --git a/tests/test-chat.cpp b/tests/test-chat.cpp index 52e23b5ac61f5..4a8ba849b3f8c 100644 --- a/tests/test-chat.cpp +++ b/tests/test-chat.cpp @@ -16,6 +16,7 @@ #include #include +#include #include using json = nlohmann::ordered_json; @@ -2138,6 +2139,154 @@ static void test_template_output_parsers() { assert_equals(true, common_chat_templates_support_enable_thinking(tmpls.get())); } + { + // LFM2 format tests + auto tmpls = read_templates("models/templates/llama-cpp-lfm2.jinja"); + std::vector end_tokens{ "<|im_end|>" }; + + auto inputs_tools_forced_json_schema = std::invoke([&]() -> common_chat_templates_inputs { + common_chat_templates_inputs inputs; + inputs.messages = { + std::invoke([&]() -> common_chat_msg { + common_chat_msg msg; + msg.role = "system"; + msg.content = "force json schema.\n"; + return msg; + }), + message_user, + }; + inputs.tools = {special_function_tool}; + return inputs; + }); + + { + auto params = common_chat_templates_apply(tmpls.get(), inputs_no_tools); + assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY, params.format); + assert_equals(false, params.grammar_lazy); + assert_equals(std::string(R"(<|im_start|>user +Hey there!<|im_end|> +<|im_start|>assistant +)"), params.prompt); + } + + { + auto params = common_chat_templates_apply(tmpls.get(), inputs_tools); + assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY, params.format); + assert_equals(false, params.grammar_lazy); + assert_equals(std::string(R"(<|im_start|>system +List of tools: <|tool_list_start|>[{"type": "function", "function": {"name": "special_function", "description": "I'm special", "parameters": {"type": "object", "properties": {"arg1": {"type": "integer", "description": "The arg."}}, "required": ["arg1"]}}}]<|tool_list_end|><|im_end|> +<|im_start|>user +Hey there!<|im_end|> +<|im_start|>assistant +)"), params.prompt); + assert_equals(true, params.grammar.empty()); + } + + { + auto params = common_chat_templates_apply(tmpls.get(), inputs_tools_forced_json_schema); + assert_equals(COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS, params.format); + assert_equals(true, params.grammar_lazy); + assert_equals(std::string(R"(<|im_start|>system +List of tools: <|tool_list_start|>[{"type": "function", "function": {"name": "special_function", "description": "I'm special", "parameters": {"type": "object", "properties": {"arg1": {"type": "integer", "description": "The arg."}}, "required": ["arg1"]}}}]<|tool_list_end|><|im_end|> +<|im_start|>user +Hey there!<|im_end|> +<|im_start|>assistant +)"), params.prompt); + assert_equals(false, params.grammar.empty()); + } + + // Test parsing regular content + assert_msg_equals(message_assist, + common_chat_parse( + "Hello, world!\nWhat's up?", + /* is_partial= */ false, + 
{COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS})); + + // Test single tool call with JSON format + common_chat_msg msg_single_tool_call; + msg_single_tool_call.role = "assistant"; + msg_single_tool_call.tool_calls.push_back({"special_function", "{\"arg1\":1}", ""}); + assert_msg_equals( + msg_single_tool_call, + common_chat_parse( + "<|tool_call_start|>[{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}]<|tool_call_end|>", + /* is_partial= */ false, + {COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS})); + + // Test tool call with string argument + common_chat_msg msg_tool_call_string; + msg_tool_call_string.role = "assistant"; + msg_tool_call_string.tool_calls.push_back({"get_weather", "{\"location\":\"Paris\"}", ""}); + assert_msg_equals( + msg_tool_call_string, + common_chat_parse( + "<|tool_call_start|>[{\"name\": \"get_weather\", \"arguments\": {\"location\": \"Paris\"}}]<|tool_call_end|>", + /* is_partial= */ false, + {COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS})); + + // Test tool call with multiple arguments + common_chat_msg msg_multi_args; + msg_multi_args.role = "assistant"; + msg_multi_args.tool_calls.push_back({"calculate", "{\"x\":10,\"y\":20,\"operation\":\"add\"}", ""}); + assert_msg_equals( + msg_multi_args, + common_chat_parse( + "<|tool_call_start|>[{\"name\": \"calculate\", \"arguments\": {\"x\": 10, \"y\": 20, \"operation\": \"add\"}}]<|tool_call_end|>", + /* is_partial= */ false, + {COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS})); + + // Test multiple tool calls in single array + common_chat_msg msg_multiple_tools; + msg_multiple_tools.role = "assistant"; + msg_multiple_tools.tool_calls.push_back({"get_weather", "{\"location\":\"Paris\"}", ""}); + msg_multiple_tools.tool_calls.push_back({"get_time", "{\"timezone\":\"UTC\"}", ""}); + assert_msg_equals( + msg_multiple_tools, + common_chat_parse( + "<|tool_call_start|>[{\"name\": \"get_weather\", \"arguments\": {\"location\": \"Paris\"}}, {\"name\": \"get_time\", \"arguments\": {\"timezone\": \"UTC\"}}]<|tool_call_end|>", + /* is_partial= */ false, + {COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS})); + + // Test tool call with content before + common_chat_msg msg_content_before_tool; + msg_content_before_tool.role = "assistant"; + msg_content_before_tool.content = "Let me check the weather for you."; + msg_content_before_tool.tool_calls.push_back({"get_weather", "{\"location\":\"Paris\"}", ""}); + assert_msg_equals( + msg_content_before_tool, + common_chat_parse( + "Let me check the weather for you.<|tool_call_start|>[{\"name\": \"get_weather\", \"arguments\": {\"location\": \"Paris\"}}]<|tool_call_end|>", + /* is_partial= */ false, + {COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS})); + + // Test tool call with content after + common_chat_msg msg_content_after_tool; + msg_content_after_tool.role = "assistant"; + msg_content_after_tool.content = "Here's the result."; + msg_content_after_tool.tool_calls.push_back({"get_weather", "{\"location\":\"Paris\"}", ""}); + assert_msg_equals( + msg_content_after_tool, + common_chat_parse( + "<|tool_call_start|>[{\"name\": \"get_weather\", \"arguments\": {\"location\": \"Paris\"}}]<|tool_call_end|>Here's the result.", + /* is_partial= */ false, + {COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS})); + + // Test tool call with newlines (common in LLM output) + common_chat_msg msg_tool_call_newlines; + msg_tool_call_newlines.role = "assistant"; + msg_tool_call_newlines.tool_calls.push_back({"get_current_time", "{\"location\":\"Paris\"}", ""}); + assert_msg_equals( + msg_tool_call_newlines, + common_chat_parse( + 
"<|tool_call_start|>[{\n \"name\": \"get_current_time\",\n \"arguments\": {\n \"location\": \"Paris\"\n }\n}]<|tool_call_end|>", + /* is_partial= */ false, + {COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS})); + + // Note: LFM2 uses JSON format for tool calls: [{"name": "...", "arguments": {...}}] + // Unlike other formats, LFM2 template does not render tool calls in conversation history, + // so we don't use test_templates() for tool call generation. Instead, the parsing tests + // above verify edge cases and format variations for the tool call output format. + } } diff --git a/tools/mtmd/clip-impl.h b/tools/mtmd/clip-impl.h index 1669fad99b36b..ad2108d1798ae 100644 --- a/tools/mtmd/clip-impl.h +++ b/tools/mtmd/clip-impl.h @@ -139,6 +139,7 @@ enum projector_type { PROJECTOR_TYPE_VOXTRAL, PROJECTOR_TYPE_LFM2, PROJECTOR_TYPE_KIMIVL, + PROJECTOR_TYPE_LIGHTONOCR, PROJECTOR_TYPE_UNKNOWN, }; @@ -161,6 +162,7 @@ static std::map PROJECTOR_TYPE_NAMES = { { PROJECTOR_TYPE_VOXTRAL, "voxtral"}, { PROJECTOR_TYPE_LFM2, "lfm2"}, { PROJECTOR_TYPE_KIMIVL, "kimivl"}, + { PROJECTOR_TYPE_LIGHTONOCR,"lightonocr"}, }; static projector_type clip_projector_type_from_string(const std::string & str) { diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index f2abf88523843..b44f0a3a28ad2 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -171,7 +171,7 @@ struct clip_hparams { int32_t n_head; int32_t n_layer; // idefics3 - int32_t preproc_image_size = 0; + int32_t preproc_image_size = 0; // aka max_dimension int32_t proj_scale_factor = 0; float image_mean[3]; @@ -621,7 +621,7 @@ struct clip_graph { } // arrangement of the [IMG_BREAK] token - { + if (model.token_embd_img_break) { // not efficient, but works // the trick is to view the embeddings as a 3D tensor with shape [n_embd, n_patches_per_row, n_rows] // and then concatenate the [IMG_BREAK] token to the end of each row, aka n_patches_per_row dimension @@ -2095,6 +2095,7 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 res = graph.build_siglip(); } break; case PROJECTOR_TYPE_PIXTRAL: + case PROJECTOR_TYPE_LIGHTONOCR: { res = graph.build_pixtral(); } break; @@ -2380,6 +2381,7 @@ struct clip_model_loader { get_u32(KEY_PROJ_SCALE_FACTOR, hparams.proj_scale_factor, false); } break; case PROJECTOR_TYPE_PIXTRAL: + case PROJECTOR_TYPE_LIGHTONOCR: { hparams.rope_theta = 10000.0f; hparams.warmup_image_size = hparams.patch_size * 8; @@ -2722,6 +2724,15 @@ struct clip_model_loader { model.mm_input_norm_w = get_tensor(TN_MM_INP_NORM, false); model.mm_patch_merger_w = get_tensor(TN_MM_PATCH_MERGER, false); } break; + case PROJECTOR_TYPE_LIGHTONOCR: + { + model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 1, "weight")); + model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 1, "bias"), false); + model.mm_2_w = get_tensor(string_format(TN_LLAVA_PROJ, 2, "weight")); + model.mm_2_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias"), false); + model.mm_input_norm_w = get_tensor(TN_MM_INP_NORM, false); + model.mm_patch_merger_w = get_tensor(TN_MM_PATCH_MERGER, false); + } break; case PROJECTOR_TYPE_ULTRAVOX: { model.conv1d_1_w = get_tensor(string_format(TN_CONV1D, 1, "weight")); @@ -3210,8 +3221,8 @@ struct image_manipulation { return {0, 0}; } - float scale = std::min(1.0f, std::min(static_cast(max_dimension) / inp_size.width, - static_cast(max_dimension) / inp_size.height)); + float scale = std::min(static_cast(max_dimension) / inp_size.width, + static_cast(max_dimension) / inp_size.height); float target_width_f = 
 
         float target_width_f  = static_cast<float>(inp_size.width)  * scale;
         float target_height_f = static_cast<float>(inp_size.height) * scale;
@@ -3374,7 +3385,7 @@ struct llava_uhd {
         // resize to overview size
         clip_image_u8_ptr resized_img(clip_image_u8_init());
-        image_manipulation::bicubic_resize(*img, *resized_img, inst.overview_size.width, inst.overview_size.height);
+        image_manipulation::resize_and_pad_image(*img, *resized_img, inst.overview_size);
         output.push_back(std::move(resized_img));
 
         if (inst.slices.empty()) {
             // no slices, just return the resized image
@@ -3576,6 +3587,9 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
         // CITE: https://github.com/huggingface/transformers/blob/main/src/transformers/models/idefics3/image_processing_idefics3.py#L737
         const clip_image_size refined_size = image_manipulation::calc_size_preserved_ratio(
             original_size, params.image_size, params.preproc_image_size);
+        // LOG_INF("%s: original size: %d x %d, refined size: %d x %d\n",
+        //     __func__, original_size.width, original_size.height,
+        //     refined_size.width, refined_size.height);
 
         llava_uhd::slice_instructions instructions;
         instructions.overview_size = clip_image_size{params.image_size, params.image_size};
@@ -3586,6 +3600,7 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
         };
         for (int y = 0; y < refined_size.height; y += params.image_size) {
             for (int x = 0; x < refined_size.width; x += params.image_size) {
+                // LOG_INF("%s: adding slice at x=%d, y=%d\n", __func__, x, y);
                 instructions.slices.push_back(llava_uhd::slice_coordinates{
                     /* x */x,
                     /* y */y,
@@ -3622,7 +3637,9 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
         res_imgs->entries.push_back(std::move(img_f32));
         return true;
 
-    } else if (ctx->proj_type() == PROJECTOR_TYPE_PIXTRAL) {
+    } else if (ctx->proj_type() == PROJECTOR_TYPE_PIXTRAL
+            || ctx->proj_type() == PROJECTOR_TYPE_LIGHTONOCR
+            ) {
         clip_image_u8 resized_image;
         auto new_size = image_manipulation::calc_size_preserved_ratio(original_size, params.patch_size, params.image_size);
         image_manipulation::bilinear_resize(*img, resized_image, new_size.width, new_size.height);
@@ -3865,12 +3882,17 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
                 n_patches = x_patch * y_patch;
             } break;
         case PROJECTOR_TYPE_PIXTRAL:
+        case PROJECTOR_TYPE_LIGHTONOCR:
            {
                 // dynamic size
                 int n_merge = params.spatial_merge_size;
                 int n_patches_x = img->nx / patch_size / (n_merge > 0 ? n_merge : 1);
                 int n_patches_y = img->ny / patch_size / (n_merge > 0 ? n_merge : 1);
-                n_patches = n_patches_y * n_patches_x + n_patches_y - 1; // + one [IMG_BREAK] per row, except the last row
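+                // token_embd_img_break is optional on this path: LightOnOCR reuses the
+                // Pixtral projector but apparently has no [IMG_BREAK] embedding, hence
+                // the conditional row-break accounting below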
+                if (ctx->model.token_embd_img_break) {
+                    n_patches = n_patches_y * n_patches_x + n_patches_y - 1; // + one [IMG_BREAK] per row, except the last row
+                } else {
+                    n_patches = n_patches_y * n_patches_x;
+                }
             } break;
         case PROJECTOR_TYPE_VOXTRAL:
         case PROJECTOR_TYPE_ULTRAVOX:
@@ -4247,6 +4269,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
             } break;
         case PROJECTOR_TYPE_PIXTRAL:
         case PROJECTOR_TYPE_KIMIVL:
+        case PROJECTOR_TYPE_LIGHTONOCR:
             {
                 // set the 2D positions
                 int n_patches_per_col = image_size_width / patch_size;
@@ -4377,6 +4400,7 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
             return ctx->model.mm_model_peg_0_b->ne[0];
         case PROJECTOR_TYPE_MLP:
         case PROJECTOR_TYPE_PIXTRAL:
+        case PROJECTOR_TYPE_LIGHTONOCR:
             return ctx->model.mm_2_w->ne[1];
         case PROJECTOR_TYPE_MLP_NORM:
             return ctx->model.mm_3_b->ne[0];
diff --git a/tools/mtmd/mtmd.cpp b/tools/mtmd/mtmd.cpp
index 4d487581ae0a0..3b901bfac8215 100644
--- a/tools/mtmd/mtmd.cpp
+++ b/tools/mtmd/mtmd.cpp
@@ -275,6 +275,11 @@ struct mtmd_context {
             img_beg = "";
             img_end = "";
 
+        } else if (proj == PROJECTOR_TYPE_LIGHTONOCR) {
+            // <|im_start|> ... (image embeddings) ... <|im_end|>
+            img_beg = "<|im_start|>";
+            img_end = "<|im_end|>";
+        }
     }
diff --git a/tools/mtmd/tests.sh b/tools/mtmd/tests.sh
index dbdf7656a66d9..c2270746360ec 100755
--- a/tools/mtmd/tests.sh
+++ b/tools/mtmd/tests.sh
@@ -70,6 +70,7 @@ add_test_vision "ggml-org/InternVL3-1B-Instruct-GGUF:Q8_0"
 add_test_vision "ggml-org/Qwen2.5-Omni-3B-GGUF:Q4_K_M"
 add_test_vision "ggml-org/LFM2-VL-450M-GGUF:Q8_0"
 add_test_vision "ggml-org/granite-docling-258M-GGUF:Q8_0"
+add_test_vision "ggml-org/LightOnOCR-1B-1025-GGUF:Q8_0"
 
 add_test_audio "ggml-org/ultravox-v0_5-llama-3_2-1b-GGUF:Q8_0"
 add_test_audio "ggml-org/Qwen2.5-Omni-3B-GGUF:Q4_K_M"
 
@@ -138,7 +139,10 @@ for i in "${!arr_hf[@]}"; do
 
     echo "$output" > $SCRIPT_DIR/output/$bin-$(echo "$hf" | tr '/' '-').log
 
-    if echo "$output" | grep -iq "new york"; then
+    # either contains "new york" or both "men" and "walk"
+    if echo "$output" | grep -iq "new york" \
+        || (echo "$output" | grep -iq "men" && echo "$output" | grep -iq "walk")
+    then
         result="$prefix \033[32mOK\033[0m: $bin $hf"
     else
         result="$prefix \033[31mFAIL\033[0m: $bin $hf"